Synthetic models for posterior distributions

Marco Raveri (marco.raveri@unige.it), Cyrille Doux (doux@lpsc.in2p3.fr), Shivam Pandey (shivampcosmo@gmail.com)

In this notebook we show how to build normalizing flow synthetic models for posterior distributions, as in Raveri, Doux and Pandey (2024), arXiv:XXXX.XXXX.
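As a quick reminder, a normalizing flow models a density by pushing a simple base distribution through an invertible map and applying the change-of-variables formula. A minimal one-dimensional sketch (an affine "flow" with made-up parameters, purely illustrative and not the model trained below):

```python
import numpy as np

# base distribution: standard normal in z; flow map: x = mu + sigma * z
mu, sigma = 1.0, 2.0

def log_q(x):
    # change of variables: log q(x) = log N(z; 0, 1) - log|dx/dz|
    z = (x - mu) / sigma
    return -0.5 * (z**2 + np.log(2.0 * np.pi)) - np.log(sigma)

# this is exactly the log-density of N(mu, sigma^2); richer flows replace
# the affine map with a trainable invertible network
```

The flows used in this notebook stack several trainable autoregressive transformations, but the density evaluation works by the same formula.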

Notebook setup:

In [1]:
# Show plots inline, and load main getdist plot module and samples class
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

# import libraries:
import sys, os
sys.path.insert(0,os.path.realpath(os.path.join(os.getcwd(),'../..')))
from getdist import plots, MCSamples
from getdist.gaussian_mixtures import GaussianND
import getdist
getdist.chains.print_load_details = False
import scipy
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# tensorflow imports:
import tensorflow as tf
import tensorflow_probability as tfp

# import the tensiometer tools that we need:
import tensiometer
from tensiometer import utilities
from tensiometer.synthetic_probability import synthetic_probability as sp

# getdist settings to ensure consistency of plots:
getdist_settings = {'ignore_rows': 0.0, 
                    'smooth_scale_2D': 0.3,
                    'smooth_scale_1D': 0.3,
                    }    
2024-09-14 13:13:54.895667: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

We start by building a random Gaussian mixture that we are going to use for tests:

In [2]:
# define the parameters of the problem:
dim = 6
num_gaussians = 3
num_samples = 10000

# we seed the random number generator to get reproducible results:
seed = 100
np.random.seed(seed)
# we define the range for the means and covariances:
mean_range = (-0.5, 0.5)
cov_scale = 0.4**2
# means and covs:
means = np.random.uniform(mean_range[0], mean_range[1], num_gaussians*dim).reshape(num_gaussians, dim)
weights = np.random.rand(num_gaussians)
weights = weights / np.sum(weights)
covs = [cov_scale*utilities.vector_to_PDM(np.random.rand(int(dim*(dim+1)/2))) for _ in range(num_gaussians)]

# cast to required precision:
means = means.astype(np.float32)
weights = weights.astype(np.float32)
covs = [cov.astype(np.float32) for cov in covs]

# initialize distribution:
distribution = tfp.distributions.Mixture(
    cat=tfp.distributions.Categorical(probs=weights),
    components=[
        tfp.distributions.MultivariateNormalTriL(loc=_m, scale_tril=tf.linalg.cholesky(_c))
        for _m, _c in zip(means, covs)
    ], name='Mixture')

# sample the distribution:
samples = distribution.sample(num_samples).numpy()
# calculate log posteriors:
logP = distribution.log_prob(samples).numpy()

# create MCSamples from the samples:
chain = MCSamples(samples=samples,
                  settings=getdist_settings,
                  loglikes=-logP,
                  name_tag='Mixture',
                  )

# we make a sanity check plot:
g = plots.get_subplot_plotter()
g.triangle_plot(chain, filled=True)
    

Base example:

We train a normalizing flow on samples of a given distribution.

We initialize and train the normalizing flow on samples of the distribution we have just defined:

In [3]:
kwargs = {
          'feedback': 2,
          'plot_every': 1000,
          'pop_size': 1,
          #'cache_dir': 'test',  # set this to a directory to cache the results
          #'root_name': 'test',  # sets the name of the flow for the cache files
        }

flow = sp.flow_from_chain(chain,  # chain with samples from the target distribution
                          **kwargs)
* Initializing samples
    - flow name: Mixture_flow
    - precision: <dtype: 'float32'>
    - flow parameters and ranges:
      param1 : [-1.15383, 1.2526]
      param2 : [-1.44704, 1.13465]
      param3 : [-1.38866, 0.841306]
      param4 : [-0.921243, 1.42548]
      param5 : [-1.98012, 1.3537]
      param6 : [-1.66968, 0.934447]
    - periodic parameters: []
    - time taken: 0.0008 seconds
* Initializing fixed bijector
    - using prior bijector: ranges
    - rescaling samples
    - time taken: 0.1862 seconds
* Initializing trainable bijector
    Building Autoregressive Flow
    - # parameters          : 6
    - periodic parameters   : None
    - # transformations     : 8
    - hidden_units          : [12, 12]
    - transformation_type   : affine
    - autoregressive_type   : masked
    - permutations          : True
    - scale_roto_shift      : False
    - activation            : <function asinh at 0x1b9e07d30>
    - time taken: 1.3858 seconds
* Initializing training dataset
    - 9000/1000 training/test samples and uniform weights
    - time taken: 1.3354 seconds
* Initializing transformed distribution
    - time taken: 0.0074 seconds
* Initializing loss function
    - using standard loss function
    - time taken: 0.0000 seconds
* Initializing training model
    - Compiling model
    - time taken: 0.0190 seconds
    - trainable parameters : 3168
    - maximum learning rate: 0.001
    - minimum learning rate: 1e-06
    - time taken: 1.6291 seconds
* Training
    - Compiling model
    - time taken: 0.0078 seconds
Epoch 1/100
20/20 - 7s - loss: 8.5123 - val_loss: 8.5250 - lr: 0.0010 - 7s/epoch - 372ms/step
Epoch 2/100
20/20 - 0s - loss: 8.4742 - val_loss: 8.4944 - lr: 0.0010 - 307ms/epoch - 15ms/step
Epoch 3/100
20/20 - 0s - loss: 8.4373 - val_loss: 8.4559 - lr: 0.0010 - 295ms/epoch - 15ms/step
Epoch 4/100
20/20 - 0s - loss: 8.3955 - val_loss: 8.4203 - lr: 0.0010 - 243ms/epoch - 12ms/step
Epoch 5/100
20/20 - 0s - loss: 8.3583 - val_loss: 8.3898 - lr: 0.0010 - 251ms/epoch - 13ms/step
Epoch 6/100
20/20 - 0s - loss: 8.3335 - val_loss: 8.3726 - lr: 0.0010 - 336ms/epoch - 17ms/step
Epoch 7/100
20/20 - 0s - loss: 8.3173 - val_loss: 8.3639 - lr: 0.0010 - 407ms/epoch - 20ms/step
Epoch 8/100
20/20 - 0s - loss: 8.3080 - val_loss: 8.3543 - lr: 0.0010 - 363ms/epoch - 18ms/step
Epoch 9/100
20/20 - 0s - loss: 8.3037 - val_loss: 8.3454 - lr: 0.0010 - 386ms/epoch - 19ms/step
Epoch 10/100
20/20 - 0s - loss: 8.2988 - val_loss: 8.3495 - lr: 0.0010 - 319ms/epoch - 16ms/step
Epoch 11/100
20/20 - 0s - loss: 8.2924 - val_loss: 8.3365 - lr: 0.0010 - 396ms/epoch - 20ms/step
Epoch 12/100
20/20 - 0s - loss: 8.2864 - val_loss: 8.3278 - lr: 0.0010 - 339ms/epoch - 17ms/step
Epoch 13/100
20/20 - 0s - loss: 8.2795 - val_loss: 8.3184 - lr: 0.0010 - 252ms/epoch - 13ms/step
Epoch 14/100
20/20 - 0s - loss: 8.2714 - val_loss: 8.3101 - lr: 0.0010 - 237ms/epoch - 12ms/step
Epoch 15/100
20/20 - 0s - loss: 8.2608 - val_loss: 8.2947 - lr: 0.0010 - 232ms/epoch - 12ms/step
Epoch 16/100
20/20 - 0s - loss: 8.2485 - val_loss: 8.2714 - lr: 0.0010 - 238ms/epoch - 12ms/step
Epoch 17/100
20/20 - 0s - loss: 8.2304 - val_loss: 8.2496 - lr: 0.0010 - 281ms/epoch - 14ms/step
Epoch 18/100
20/20 - 0s - loss: 8.2064 - val_loss: 8.2218 - lr: 0.0010 - 255ms/epoch - 13ms/step
Epoch 19/100
20/20 - 0s - loss: 8.1793 - val_loss: 8.1796 - lr: 0.0010 - 354ms/epoch - 18ms/step
Epoch 20/100
20/20 - 0s - loss: 8.1464 - val_loss: 8.1397 - lr: 0.0010 - 457ms/epoch - 23ms/step
Epoch 21/100
20/20 - 0s - loss: 8.1103 - val_loss: 8.0871 - lr: 0.0010 - 361ms/epoch - 18ms/step
Epoch 22/100
20/20 - 0s - loss: 8.0677 - val_loss: 8.0290 - lr: 0.0010 - 300ms/epoch - 15ms/step
Epoch 23/100
20/20 - 0s - loss: 8.0113 - val_loss: 7.9604 - lr: 0.0010 - 217ms/epoch - 11ms/step
Epoch 24/100
20/20 - 0s - loss: 7.9549 - val_loss: 7.8946 - lr: 0.0010 - 248ms/epoch - 12ms/step
Epoch 25/100
20/20 - 0s - loss: 7.8993 - val_loss: 7.8299 - lr: 0.0010 - 233ms/epoch - 12ms/step
Epoch 26/100
20/20 - 0s - loss: 7.8506 - val_loss: 7.7905 - lr: 0.0010 - 245ms/epoch - 12ms/step
Epoch 27/100
20/20 - 0s - loss: 7.8143 - val_loss: 7.7475 - lr: 0.0010 - 232ms/epoch - 12ms/step
Epoch 28/100
20/20 - 0s - loss: 7.7840 - val_loss: 7.7150 - lr: 0.0010 - 318ms/epoch - 16ms/step
Epoch 29/100
20/20 - 0s - loss: 7.7567 - val_loss: 7.7113 - lr: 0.0010 - 286ms/epoch - 14ms/step
Epoch 30/100
20/20 - 0s - loss: 7.7313 - val_loss: 7.6872 - lr: 0.0010 - 231ms/epoch - 12ms/step
Epoch 31/100
20/20 - 0s - loss: 7.7210 - val_loss: 7.6730 - lr: 0.0010 - 227ms/epoch - 11ms/step
Epoch 32/100
20/20 - 0s - loss: 7.7022 - val_loss: 7.6678 - lr: 0.0010 - 245ms/epoch - 12ms/step
Epoch 33/100
20/20 - 0s - loss: 7.6828 - val_loss: 7.6591 - lr: 0.0010 - 228ms/epoch - 11ms/step
Epoch 34/100
20/20 - 0s - loss: 7.6691 - val_loss: 7.6552 - lr: 0.0010 - 238ms/epoch - 12ms/step
Epoch 35/100
20/20 - 0s - loss: 7.6629 - val_loss: 7.6429 - lr: 0.0010 - 228ms/epoch - 11ms/step
Epoch 36/100
20/20 - 0s - loss: 7.6516 - val_loss: 7.6417 - lr: 0.0010 - 255ms/epoch - 13ms/step
Epoch 37/100
20/20 - 0s - loss: 7.6431 - val_loss: 7.6364 - lr: 0.0010 - 225ms/epoch - 11ms/step
Epoch 38/100
20/20 - 0s - loss: 7.6307 - val_loss: 7.6275 - lr: 0.0010 - 247ms/epoch - 12ms/step
Epoch 39/100
20/20 - 0s - loss: 7.6272 - val_loss: 7.6251 - lr: 0.0010 - 223ms/epoch - 11ms/step
Epoch 40/100
20/20 - 0s - loss: 7.6203 - val_loss: 7.6248 - lr: 0.0010 - 245ms/epoch - 12ms/step
Epoch 41/100
20/20 - 0s - loss: 7.6124 - val_loss: 7.6197 - lr: 0.0010 - 227ms/epoch - 11ms/step
Epoch 42/100
20/20 - 0s - loss: 7.6041 - val_loss: 7.6115 - lr: 0.0010 - 252ms/epoch - 13ms/step
Epoch 43/100
20/20 - 0s - loss: 7.5983 - val_loss: 7.6161 - lr: 0.0010 - 256ms/epoch - 13ms/step
Epoch 44/100
20/20 - 0s - loss: 7.5938 - val_loss: 7.6148 - lr: 0.0010 - 255ms/epoch - 13ms/step
Epoch 45/100
20/20 - 0s - loss: 7.5911 - val_loss: 7.6101 - lr: 0.0010 - 231ms/epoch - 12ms/step
Epoch 46/100
20/20 - 0s - loss: 7.5844 - val_loss: 7.6212 - lr: 0.0010 - 239ms/epoch - 12ms/step
Epoch 47/100
20/20 - 0s - loss: 7.5758 - val_loss: 7.5985 - lr: 0.0010 - 257ms/epoch - 13ms/step
Epoch 48/100
20/20 - 0s - loss: 7.5707 - val_loss: 7.5896 - lr: 0.0010 - 241ms/epoch - 12ms/step
Epoch 49/100
20/20 - 0s - loss: 7.5657 - val_loss: 7.5936 - lr: 0.0010 - 233ms/epoch - 12ms/step
Epoch 50/100
20/20 - 0s - loss: 7.5604 - val_loss: 7.5947 - lr: 0.0010 - 250ms/epoch - 13ms/step
Epoch 51/100
20/20 - 0s - loss: 7.5541 - val_loss: 7.5824 - lr: 0.0010 - 222ms/epoch - 11ms/step
Epoch 52/100
20/20 - 0s - loss: 7.5524 - val_loss: 7.5966 - lr: 0.0010 - 242ms/epoch - 12ms/step
Epoch 53/100
20/20 - 0s - loss: 7.5484 - val_loss: 7.5904 - lr: 0.0010 - 220ms/epoch - 11ms/step
Epoch 54/100
20/20 - 0s - loss: 7.5406 - val_loss: 7.5793 - lr: 0.0010 - 226ms/epoch - 11ms/step
Epoch 55/100
20/20 - 0s - loss: 7.5354 - val_loss: 7.5854 - lr: 0.0010 - 222ms/epoch - 11ms/step
Epoch 56/100
20/20 - 0s - loss: 7.5305 - val_loss: 7.5783 - lr: 0.0010 - 220ms/epoch - 11ms/step
Epoch 57/100
20/20 - 0s - loss: 7.5269 - val_loss: 7.5747 - lr: 0.0010 - 251ms/epoch - 13ms/step
Epoch 58/100
20/20 - 0s - loss: 7.5224 - val_loss: 7.5723 - lr: 0.0010 - 234ms/epoch - 12ms/step
Epoch 59/100
20/20 - 0s - loss: 7.5217 - val_loss: 7.5684 - lr: 0.0010 - 256ms/epoch - 13ms/step
Epoch 60/100
20/20 - 0s - loss: 7.5125 - val_loss: 7.5695 - lr: 0.0010 - 226ms/epoch - 11ms/step
Epoch 61/100
20/20 - 0s - loss: 7.5098 - val_loss: 7.5752 - lr: 0.0010 - 246ms/epoch - 12ms/step
Epoch 62/100
20/20 - 0s - loss: 7.5106 - val_loss: 7.5723 - lr: 0.0010 - 224ms/epoch - 11ms/step
Epoch 63/100
20/20 - 0s - loss: 7.5063 - val_loss: 7.5720 - lr: 0.0010 - 245ms/epoch - 12ms/step
Epoch 64/100
20/20 - 0s - loss: 7.4991 - val_loss: 7.5553 - lr: 0.0010 - 225ms/epoch - 11ms/step
Epoch 65/100
20/20 - 0s - loss: 7.4974 - val_loss: 7.5715 - lr: 0.0010 - 244ms/epoch - 12ms/step
Epoch 66/100
20/20 - 0s - loss: 7.4936 - val_loss: 7.5572 - lr: 0.0010 - 267ms/epoch - 13ms/step
Epoch 67/100
20/20 - 0s - loss: 7.4896 - val_loss: 7.5673 - lr: 0.0010 - 263ms/epoch - 13ms/step
Epoch 68/100
20/20 - 0s - loss: 7.4917 - val_loss: 7.5676 - lr: 0.0010 - 246ms/epoch - 12ms/step
Epoch 69/100
20/20 - 0s - loss: 7.4899 - val_loss: 7.5510 - lr: 0.0010 - 241ms/epoch - 12ms/step
Epoch 70/100
20/20 - 0s - loss: 7.4820 - val_loss: 7.5672 - lr: 0.0010 - 240ms/epoch - 12ms/step
Epoch 71/100
20/20 - 0s - loss: 7.4838 - val_loss: 7.5649 - lr: 0.0010 - 234ms/epoch - 12ms/step
Epoch 72/100
20/20 - 0s - loss: 7.4784 - val_loss: 7.5387 - lr: 0.0010 - 222ms/epoch - 11ms/step
Epoch 73/100
20/20 - 0s - loss: 7.4747 - val_loss: 7.5469 - lr: 0.0010 - 264ms/epoch - 13ms/step
Epoch 74/100
20/20 - 0s - loss: 7.4719 - val_loss: 7.5364 - lr: 0.0010 - 230ms/epoch - 12ms/step
Epoch 75/100
20/20 - 0s - loss: 7.4706 - val_loss: 7.5429 - lr: 0.0010 - 238ms/epoch - 12ms/step
Epoch 76/100
20/20 - 0s - loss: 7.4687 - val_loss: 7.5497 - lr: 0.0010 - 245ms/epoch - 12ms/step
Epoch 77/100
20/20 - 0s - loss: 7.4676 - val_loss: 7.5424 - lr: 0.0010 - 257ms/epoch - 13ms/step
Epoch 78/100
20/20 - 0s - loss: 7.4587 - val_loss: 7.5279 - lr: 0.0010 - 243ms/epoch - 12ms/step
Epoch 79/100
20/20 - 0s - loss: 7.4566 - val_loss: 7.5306 - lr: 0.0010 - 257ms/epoch - 13ms/step
Epoch 80/100
20/20 - 0s - loss: 7.4610 - val_loss: 7.5343 - lr: 0.0010 - 259ms/epoch - 13ms/step
Epoch 81/100
20/20 - 0s - loss: 7.4550 - val_loss: 7.5238 - lr: 0.0010 - 283ms/epoch - 14ms/step
Epoch 82/100
20/20 - 0s - loss: 7.4551 - val_loss: 7.5239 - lr: 0.0010 - 269ms/epoch - 13ms/step
Epoch 83/100
20/20 - 0s - loss: 7.4543 - val_loss: 7.5313 - lr: 0.0010 - 356ms/epoch - 18ms/step
Epoch 84/100
20/20 - 0s - loss: 7.4469 - val_loss: 7.5200 - lr: 0.0010 - 269ms/epoch - 13ms/step
Epoch 85/100
20/20 - 0s - loss: 7.4451 - val_loss: 7.5127 - lr: 0.0010 - 326ms/epoch - 16ms/step
Epoch 86/100
20/20 - 0s - loss: 7.4447 - val_loss: 7.5209 - lr: 0.0010 - 296ms/epoch - 15ms/step
Epoch 87/100
20/20 - 0s - loss: 7.4401 - val_loss: 7.5237 - lr: 0.0010 - 255ms/epoch - 13ms/step
Epoch 88/100
20/20 - 0s - loss: 7.4392 - val_loss: 7.5105 - lr: 0.0010 - 422ms/epoch - 21ms/step
Epoch 89/100
20/20 - 0s - loss: 7.4383 - val_loss: 7.5114 - lr: 0.0010 - 269ms/epoch - 13ms/step
Epoch 90/100
20/20 - 0s - loss: 7.4348 - val_loss: 7.5108 - lr: 0.0010 - 243ms/epoch - 12ms/step
Epoch 91/100
20/20 - 0s - loss: 7.4319 - val_loss: 7.5058 - lr: 0.0010 - 239ms/epoch - 12ms/step
Epoch 92/100
20/20 - 0s - loss: 7.4302 - val_loss: 7.5086 - lr: 0.0010 - 274ms/epoch - 14ms/step
Epoch 93/100
20/20 - 0s - loss: 7.4276 - val_loss: 7.4989 - lr: 0.0010 - 275ms/epoch - 14ms/step
Epoch 94/100
20/20 - 0s - loss: 7.4282 - val_loss: 7.5105 - lr: 0.0010 - 265ms/epoch - 13ms/step
Epoch 95/100
20/20 - 0s - loss: 7.4265 - val_loss: 7.5018 - lr: 0.0010 - 248ms/epoch - 12ms/step
Epoch 96/100
20/20 - 0s - loss: 7.4229 - val_loss: 7.4923 - lr: 0.0010 - 263ms/epoch - 13ms/step
Epoch 97/100
20/20 - 0s - loss: 7.4208 - val_loss: 7.4985 - lr: 0.0010 - 272ms/epoch - 14ms/step
Epoch 98/100
20/20 - 0s - loss: 7.4175 - val_loss: 7.4952 - lr: 0.0010 - 367ms/epoch - 18ms/step
Epoch 99/100
20/20 - 0s - loss: 7.4174 - val_loss: 7.4923 - lr: 0.0010 - 269ms/epoch - 13ms/step
Epoch 100/100
20/20 - 0s - loss: 7.4180 - val_loss: 7.4968 - lr: 0.0010 - 282ms/epoch - 14ms/step
* Population optimizer:
    - best model is number 1
    - best loss function is 7.42
    - best validation loss function is 7.5
    - population losses [7.5]
In [4]:
# we can plot training summaries to make sure training went smoothly:
flow.training_plot()
In [5]:
# and we can print the training summary:
flow.print_training_summary()
loss         : 7.4180
val_loss     : 7.4968
lr           : 0.0010
chi2Z_ks     : 0.0352
chi2Z_ks_p   : 0.1644
loss_rate    : 6.1846e-04
val_loss_rate: 0.0044
In [6]:
# we can triangle plot the flow to see how well it has learned the target distribution:
g = plots.get_subplot_plotter()
g.triangle_plot([chain, flow.MCSamples(20000)], 
                params=flow.param_names,
                filled=True)
In [7]:
# this looks nice but not perfect, let's train for longer:
flow.feedback = 1
flow.train(epochs=300, verbose=-1);  # verbose = -1 uses tqdm progress bar
In [8]:
# we can plot training summaries to make sure training went smoothly:
flow.training_plot()

If you train for long enough, you should start seeing the learning rate adapt to the plateauing (noisy) loss function.

This means that the flow is learning finer and finer features, and it is a good indication that training is converging. If you push it further, at some point the flow will start overfitting and training will stop.
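The adaptation mentioned above is a plateau-style schedule: when the validation loss stops improving, the learning rate is decreased until it hits a floor. A minimal sketch of the logic (illustrative only; the `patience`, `factor` and `min_lr` values here are assumptions and may differ from the actual callback configuration used by tensiometer):

```python
def schedule_lr(val_losses, lr, patience=5, factor=0.5, min_lr=1e-6):
    # decrease the learning rate if the validation loss has not improved
    # over the last `patience` epochs, never going below `min_lr`
    if len(val_losses) > patience and \
            min(val_losses[-patience:]) >= min(val_losses[:-patience]):
        lr = max(lr * factor, min_lr)
    return lr
```

With a flat loss history the rate keeps halving; once it reaches the floor and the loss still does not improve, training can be stopped.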

Now let's see what the marginal distributions look like:

In [9]:
# we can triangle plot the flow to see how well it has learned the target distribution:
g = plots.get_subplot_plotter()
g.triangle_plot([chain, 
                 flow.MCSamples(20000)  # this flow method returns a MCSamples object
                 ], 
                params=flow.param_names,
                filled=True)

This is now much better!

We can use the trained flow to perform several operations. For example, let's compute log-likelihoods:

In [10]:
samples = flow.MCSamples(20000)
logP = flow.log_probability(flow.cast(samples.samples)).numpy()
samples.addDerived(logP, name='logP', label='\\log P')
samples.updateBaseStatistics();

# now let's plot everything:
g = plots.get_subplot_plotter()
g.triangle_plot([samples, chain], 
                plot_3d_with_param='logP',
                filled=False)

Here we can appreciate a beautiful display of a projection effect: the marginal distribution of $p_5$ peaks at a positive value, while the logP coloring clearly shows that the peak of the full distribution is the negative one.

If you are interested in a systematic understanding of these types of effects, check the corresponding tensiometer tutorial!

Average flow example:

A more advanced flow model consists of training several flows and using them as a weighted mixture normalizing flow model.

This flow model reduces the variance of the flow in regions that are sparse in samples (as different flow models will hallucinate differently).
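The combination step amounts to evaluating a weighted mixture of the individual flow densities, which in log space is a weighted logsumexp. A minimal sketch of this form (illustrative; the actual tensiometer implementation may differ in details such as how the weights are chosen):

```python
import numpy as np
from scipy.special import logsumexp

def average_log_prob(log_probs, weights):
    """Mixture log-density: log sum_i w_i q_i(x), computed stably.

    log_probs: array of shape (n_flows, n_samples), per-flow log q_i(x)
    weights:   array of shape (n_flows,), summing to one
    """
    return logsumexp(np.log(weights)[:, None] + log_probs, axis=0)
```

Where one flow hallucinates a spurious low-density feature, the other flows pull the mixture back towards the common, better-constrained answer.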

Let's try averaging 5 flow models (note that we could do this in parallel with MPI on bigger machines):

In [11]:
kwargs = {
          'feedback': 1,
          'verbose': -1,
          'plot_every': 1000,
          'pop_size': 1,
          'num_flows': 5,
          'epochs': 400,
        }

average_flow = sp.average_flow_from_chain(chain,  # chain with samples from the target distribution
                                          **kwargs)
Warning: MPI is incompatible with no cache. Disabling MPI.
Training flow 0
Training flow 1
Training flow 2
Training flow 3
Training flow 4
In [12]:
# most methods are implemented for the average flow as well:
average_flow.training_plot()
In [13]:
# and we can print the training summary, which in this case contains more info:
average_flow.print_training_summary()
Number of flows: 5
Flow weights   : [0.2  0.19 0.2  0.2  0.21]
loss         : [7.26 7.28 7.3  7.25 7.25]
val_loss     : [7.25 7.26 7.24 7.21 7.2 ]
lr           : [3.16e-06 3.16e-05 1.00e-06 3.16e-05 1.00e-04]
chi2Z_ks     : [0.04 0.05 0.05 0.03 0.04]
chi2Z_ks_p   : [0.06 0.02 0.   0.17 0.15]
loss_rate    : [-1.00e-05 -1.10e-04 -4.82e-05  1.19e-04 -3.32e-04]
val_loss_rate: [-5.25e-06 -4.37e-04 -2.77e-05 -1.69e-03 -1.18e-03]
In [14]:
avg_samples = average_flow.MCSamples(20000)
avg_samples.name_tag = 'Average Flow'
temp_samples = [_f.MCSamples(20000) for _f in average_flow.flows]
for i, _s in enumerate(temp_samples):
    _s.name_tag = _s.name_tag + f'_{i}'
# let's plot the flows:
g = plots.get_subplot_plotter()
g.triangle_plot([chain, avg_samples] + temp_samples,
                filled=False)
WARNING:tensorflow:5 out of the last 8 calls to <function FlowCallback.log_probability at 0x1d21fbc10> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:6 out of the last 9 calls to <function FlowCallback.log_probability at 0x1d34d2670> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
In [15]:
logP = average_flow.log_probability(average_flow.cast(avg_samples.samples)).numpy()
avg_samples.addDerived(logP, name='logP', label='\\log P')
avg_samples.updateBaseStatistics();

# now let's plot everything:
g = plots.get_subplot_plotter()
g.triangle_plot([avg_samples, chain], 
                plot_3d_with_param='logP',
                filled=False)

Real world application: joint parameter estimation

In this example we perform a flow-based analysis of a joint posterior.

The idea is that we have posterior samples from two independent experiments; we learn the two posteriors and then combine them to form the joint posterior.

Note that we are assuming, as is true in this example, that the prior is the same for the two experiments and flat (so that we are not double-counting the prior).

This procedure was used, for example, in Gatti, Campailla et al (2024), arXiv:2405.10881.
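To see why summing the two flow log-probabilities gives the joint posterior, note that with a shared flat prior $\mathcal{P}_{12} \propto \mathcal{L}_1 \mathcal{L}_2 \Pi \propto \mathcal{P}_1 \mathcal{P}_2$. A one-dimensional Gaussian toy check of this product rule (toy numbers, not the chains below):

```python
import numpy as np

# two Gaussian posteriors for the same parameter, under a shared flat prior
mu1, s1 = 0.0, 1.0
mu2, s2 = 2.0, 1.0

# the product of the two posteriors is Gaussian, with inverse-variance weights:
s12 = 1.0 / np.sqrt(1.0 / s1**2 + 1.0 / s2**2)
mu12 = s12**2 * (mu1 / s1**2 + mu2 / s2**2)
# mu12 is the inverse-variance weighted mean; s12 is smaller than either input width
```

The flow-based version below does the same thing, but with the full non-Gaussian densities learned from the chains.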

In [16]:
# we start by loading up the posteriors:

# load the samples (remove no burn in since the example chains have already been cleaned):
chains_dir = './../../test_chains/'
# the Planck 2018 TTTEEE chain:
chain_1 = getdist.mcsamples.loadMCSamples(file_root=chains_dir+'Planck18TTTEEE', no_cache=True, settings=getdist_settings)
# the DES Y1 3x2 chain:
chain_2 = getdist.mcsamples.loadMCSamples(file_root=chains_dir+'DES', no_cache=True, settings=getdist_settings)
# the joint chain:
chain_12 = getdist.mcsamples.loadMCSamples(file_root=chains_dir+'Planck18TTTEEE_DES', no_cache=True, settings=getdist_settings)

# let's add omegab as a derived parameter:
for _ch in [chain_1, chain_2, chain_12]:
    _p = _ch.getParams()
    _h = _p.H0 / 100.
    _ch.addDerived(_p.omegabh2 / _h**2, name='omegab', label='\\Omega_b')
    _ch.updateBaseStatistics()

# we define the parameters of the problem:
param_names = ['H0', 'omegam', 'sigma8', 'ns', 'omegab']

# and then do a sanity check plot:
g = plots.get_subplot_plotter()
g.triangle_plot([chain_1, chain_2, chain_12], params=param_names, filled=True)
In [17]:
# we then train the flows on the base parameters that we want to combine (note that for this exercise we should include all shared parameters):
kwargs = {
          'feedback': 1,
          'verbose': -1,
          'plot_every': 1000,
          'pop_size': 1,
          'num_flows': 3,
          'epochs': 400,
        }

# actual flow training:
flow_1 = sp.average_flow_from_chain(chain_1, param_names=param_names, **kwargs)
flow_2 = sp.average_flow_from_chain(chain_2, param_names=param_names, **kwargs)
flow_12 = sp.average_flow_from_chain(chain_12, param_names=param_names, **kwargs)

# plot to make sure training went well:
flow_1.training_plot()
flow_2.training_plot()
flow_12.training_plot()
Warning: MPI is incompatible with no cache. Disabling MPI.
Training flow 0
Training flow 1
Training flow 2
Warning: MPI is incompatible with no cache. Disabling MPI.
Training flow 0
Training flow 1
Training flow 2
Warning: MPI is incompatible with no cache. Disabling MPI.
Training flow 0
Training flow 1
Training flow 2
In [18]:
# sanity check triangle plot:
g = plots.get_subplot_plotter()
g.triangle_plot([chain_1, flow_1.MCSamples(20000, settings=getdist_settings),
                 chain_2, flow_2.MCSamples(20000, settings=getdist_settings),
                 chain_12, flow_12.MCSamples(20000, settings=getdist_settings),
                 ], 
                params=param_names,
                filled=False)
# we log scale the y axis of the 1D marginal panels so that we can appreciate the accuracy of the flow in the tails:
for i in range(len(param_names)):
    _ax = g.subplots[i, i]
    _ax.set_yscale('log')
    _ax.set_ylim([1.e-5, 1.0])
    _ax.set_ylabel('$\\log P$')
    _ax.tick_params(axis='y', which='both', labelright='on')
    _ax.yaxis.set_label_position("right")    
In [19]:
# now we can define the joint posterior:
def joint_log_posterior(H0, omegam, sigma8, ns, omegab):
    params = [H0, omegam, sigma8, ns, omegab]
    return [flow_1.log_probability(flow_1.cast(params)).numpy() + flow_2.log_probability(flow_2.cast(params)).numpy()]

# and sample it:
from cobaya.run import run
from getdist.mcsamples import MCSamplesFromCobaya

parameters = {}
for key in param_names:
    parameters[key] = {"prior": {"min": 1.01*max(flow_1.parameter_ranges[key][0], flow_2.parameter_ranges[key][0]),
                                 "max": 0.99*min(flow_1.parameter_ranges[key][1], flow_2.parameter_ranges[key][1])},
                       "latex": flow_1.param_labels[flow_1.param_names.index(key)]}
info = {
    "likelihood": {"joint_log_posterior": joint_log_posterior},
    "params": parameters,
    }
In [20]:
# MCMC sample:

# we need a reasonably good initial proposal and starting point; we get them from one of the flows:
flow_1_samples = flow_1.sample(10000)
flow_1_logPs = flow_1.log_probability(flow_1_samples).numpy()
flow_1_maxP_sample = flow_1_samples[np.argmax(flow_1_logPs)].numpy()

# we need a good starting point otherwise this will take long...
for _i, _k in enumerate(parameters.keys()):
    info['params'][_k]['ref'] = flow_1_maxP_sample[_i]

info["sampler"] = {"mcmc": 
                {'covmat': np.cov(flow_1_samples.numpy().T),
                 'covmat_params': param_names,
                 'max_tries': np.inf,
                 'Rminus1_stop': 0.01,
                 'learn_proposal_Rminus1_max': 30.,
                 'learn_proposal_Rminus1_max_early': 30.,
                 'measure_speeds': False,
                 'Rminus1_single_split': 10,
                 }}
info['debug'] = 100  # note this is an insane hack to disable very verbose output...
updated_info, sampler = run(info)
joint_chain = MCSamplesFromCobaya(updated_info, sampler.products()["sample"], ignore_rows=0.3, settings=getdist_settings)
In [21]:
## Nested sampling sample:
#_dim = len(flow_1.param_names)
#
#info["sampler"] = {"polychord": {'nlive': 50*_dim,
#                                 'measure_speeds': False,
#                                 'num_repeats': 2*_dim,
#                                 'nprior': 10*25*_dim,
#                                 'do_clustering': True,
#                                 'precision_criterion': 0.01,
#                                 'boost_posterior': 10, 
#                                 'feedback': 0,
#                                 },
#                    }
#info['debug'] = 100  # note this is an insane hack to disable very verbose output...
#updated_info, sampler = run(info)
#joint_chain = MCSamplesFromCobaya(updated_info, sampler.products()["sample"], settings=getdist_settings)
In [22]:
joint_chain.name_tag = 'Flow Joint'
chain_12.name_tag = 'Real Joint (Planck + DES)'

# sanity check triangle plot:
g = plots.get_subplot_plotter()
g.triangle_plot([joint_chain, chain_12], 
                params=param_names,
                filled=False)

As we can see, this works fairly well, given that the two experiments are in some tension and do not overlap significantly.

Make sure you check the consistency of the experiments before combining them, so that the joint flow posterior samples a region where the flows are well trained.

You can check the example notebook in this documentation for how to compute tensions between two experiments.

Advanced Topic: accurate likelihood values

For some applications we need to push the local accuracy of the flow model. In this case we need to provide exact probability values (up to a normalization constant) for the training set.

These are then used to build a part of the loss function that rewards local accuracy of probability values. This second part of the loss function is the estimated evidence error. By default the code adaptively mixes the two loss functions to find an optimal solution.

As a downside, we can only train a flow that preserves all the parameters of the distribution, i.e. we cannot train on marginalized parameters (as we have done in the previous examples).
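The evidence estimate behind this loss term can be sketched with importance sampling: for an unnormalized target $P^*$ and flow density $q$, $Z = E_q[P^*(x)/q(x)]$, and the scatter of the importance ratios sets the error on the estimate. A numpy toy with a Gaussian target of known evidence $Z=2$ (an illustration of the general idea; the actual tensiometer estimator may differ in details):

```python
import numpy as np

rng = np.random.default_rng(0)

# flow density q = N(0, 1); unnormalized target P* = 2 * N(0.5, 1), so Z = 2
x = rng.standard_normal(20000)                       # samples from the "flow"
log_q = -0.5 * (x**2 + np.log(2.0 * np.pi))
log_Pstar = np.log(2.0) - 0.5 * ((x - 0.5)**2 + np.log(2.0 * np.pi))

log_r = log_Pstar - log_q                            # log importance ratios
logZ = np.log(np.mean(np.exp(log_r)))                # estimate of log Z
# rough error on log Z: relative standard error of the mean of the ratios
err = np.std(np.exp(log_r)) / np.sqrt(len(x)) / np.exp(logZ)
```

When the flow matches the target poorly, the ratios scatter widely, the error blows up, and the loss term penalizing it pushes the flow towards better local accuracy.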

For more details see Raveri, Doux and Pandey (2024).

In [23]:
ev, eer = flow.evidence()
print(f'log(Z) = {ev} +- {eer}')
log(Z) = 0.10397262871265411 +- 0.7186664938926697

We can see that the value is close to what it should be (zero, since the original distribution is normalized), but the estimated error is still fairly high.

Since we have (normalized) log P values, we can check the local reliability of the normalizing flow:

In [24]:
validation_flow_log10_P = flow.log_probability(flow.cast(chain.samples[flow.test_idx, :])).numpy()/np.log(10.)
validation_samples_log10_P = -chain.loglikes[flow.test_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist
training_flow_log10_P = flow.log_probability(flow.cast(chain.samples[flow.training_idx, :])).numpy()/np.log(10.)
training_samples_log10_P = -chain.loglikes[flow.training_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist

# do the plot:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

ax1.scatter(training_samples_log10_P - np.amax(training_samples_log10_P), training_flow_log10_P - training_samples_log10_P, s=0.1, label='training')
ax1.scatter(validation_samples_log10_P - np.amax(validation_samples_log10_P), validation_flow_log10_P - validation_samples_log10_P, s=0.5, label='validation')
ax1.legend()
ax1.axhline(0, color='k', linestyle='--')
ax1.set_xlabel('$\log_{10}(P_{\\rm true}/P_{\\rm max})$')
ax1.set_ylabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax1.set_ylim(-1.0, 1.0)

ax2.hist(training_flow_log10_P - training_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='training')
ax2.hist(validation_flow_log10_P - validation_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='validation')
ax2.legend()
ax2.axvline(0, color='k', linestyle='--')
ax2.set_xlabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax2.set_xlim([-1.0, 1.0])

plt.tight_layout()
plt.show()
[image: scatter and histograms of flow minus true log-probabilities for training and validation samples]

We can clearly see that the local accuracy of the flow in the full-dimensional space is not high: as we move into the tails we quickly accumulate large errors. The scatter in this plot sets the estimated error on the evidence, which is rather large and dominated by the outliers in the tails.
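The connection between this scatter and the evidence error can be sketched with plain numpy: drawing samples from the flow density q and importance-weighting the unnormalized posterior gives the evidence, and the spread of the log-density differences controls how noisy that estimate is. A toy one-dimensional illustration (not the tensiometer implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

# toy 1D problem: the "true" unnormalized posterior is a standard
# normalized Gaussian (so the true evidence is Z = 1, log Z = 0),
# while the "flow" is a slightly wider Gaussian:
def log_p_unnorm(x):
    return -0.5 * x**2 - 0.5 * np.log(2.0 * np.pi)

sigma_q = 1.2  # mismatch between flow and target
x = rng.normal(0.0, sigma_q, size=100_000)  # samples from the flow
log_q = -0.5 * (x / sigma_q)**2 - 0.5 * np.log(2.0 * np.pi * sigma_q**2)

# importance ratios p~/q; their mean over flow samples estimates Z:
log_r = log_p_unnorm(x) - log_q
Z_hat = np.mean(np.exp(log_r))

# the spread of log r is what the scatter plot visualizes: a larger
# density mismatch (especially in the tails) inflates it, and with
# it the error on the evidence estimate:
log_r_spread = np.std(log_r)
```

With a well-matched flow `log_r` is nearly constant and the evidence estimate is precise; tail outliers with large `|log_r|` are exactly what drives the large error bars seen above.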

Averaging several flows usually improves the situation, in particular on the validation sample.
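One common way to average flows is to treat them as an equal-weight mixture, so the averaged log-density is a log-sum-exp of the individual ones and the random errors of the individual flows partially cancel. A minimal numpy sketch of that idea (illustrative; not necessarily the exact scheme used by `average_flow_from_chain`):

```python
import numpy as np

def averaged_log_prob(log_probs):
    """Log-density of an equal-weight mixture of flows, given an
    (n_flows, n_samples) array of per-flow log-densities."""
    log_probs = np.asarray(log_probs)
    m = log_probs.max(axis=0)  # stabilize the exponentials
    return m + np.log(np.mean(np.exp(log_probs - m), axis=0))

# toy check: two flows that over/under-estimate a target log-density;
# the mixture is much closer to the target than either flow alone:
target = np.full(5, -1.0)
flow_a = target + 0.3  # overestimates everywhere
flow_b = target - 0.3  # underestimates everywhere
avg = averaged_log_prob([flow_a, flow_b])
```

Because the mixture averages densities (not log-densities), uncorrelated errors of the individual flows tend to cancel, which is why the validation scatter shrinks when averaging.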

In [25]:
ev, eer = average_flow.evidence()
print(f'log(Z) = {ev} +- {eer}')
log(Z) = 0.07254701852798462 +- 0.48980656266212463
In [26]:
validation_flow_log10_P = average_flow.log_probability(average_flow.cast(chain.samples[average_flow.test_idx, :])).numpy()/np.log(10.)
validation_samples_log10_P = -chain.loglikes[average_flow.test_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist
training_flow_log10_P = average_flow.log_probability(average_flow.cast(chain.samples[average_flow.training_idx, :])).numpy()/np.log(10.)
training_samples_log10_P = -chain.loglikes[average_flow.training_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist

# do the plot:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

ax1.scatter(training_samples_log10_P - np.amax(training_samples_log10_P), training_flow_log10_P - training_samples_log10_P, s=0.1, label='training')
ax1.scatter(validation_samples_log10_P - np.amax(validation_samples_log10_P), validation_flow_log10_P - validation_samples_log10_P, s=0.5, label='validation')
ax1.legend()
ax1.axhline(0, color='k', linestyle='--')
ax1.set_xlabel('$\log_{10}(P_{\\rm true}/P_{\\rm max})$')
ax1.set_ylabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax1.set_ylim(-1.0, 1.0)

ax2.hist(training_flow_log10_P - training_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='training')
ax2.hist(validation_flow_log10_P - validation_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='validation')
ax2.legend()
ax2.axvline(0, color='k', linestyle='--')
ax2.set_xlabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax2.set_xlim([-1.0, 1.0])

plt.tight_layout()
plt.show()
[image: scatter and histograms of flow minus true log-probabilities for training and validation samples]

This looks significantly better, and indeed the error on the evidence estimate is lower.

If we want to do better we need to train with the evidence-error loss, as discussed in the reference paper for this notebook.

In [27]:
kwargs = {
          'feedback': 1,
          'verbose': -1,
          'plot_every': 1000,
          'pop_size': 1,
          'num_flows': 1,
          'epochs': 400,
          'loss_mode': 'softadapt',
        }

average_flow_2 = sp.average_flow_from_chain(chain,  # parameter difference chain
                                            **kwargs)
Warning: MPI is incompatible with no cache. Disabling MPI.
Training flow 0
In [28]:
average_flow_2.training_plot()
[image: training summary plots]

As we can see the training plots are substantially more complicated, as we are now monitoring several additional quantities.

In [29]:
ev, eer = average_flow_2.evidence()
print(f'log(Z) = {ev} +- {eer}')
log(Z) = 0.12380771338939667 +- 0.4636903405189514
In [30]:
validation_flow_log10_P = average_flow_2.log_probability(average_flow_2.cast(chain.samples[average_flow_2.test_idx, :])).numpy()/np.log(10.)
validation_samples_log10_P = -chain.loglikes[average_flow_2.test_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist
training_flow_log10_P = average_flow_2.log_probability(average_flow_2.cast(chain.samples[average_flow_2.training_idx, :])).numpy()/np.log(10.)
training_samples_log10_P = -chain.loglikes[average_flow_2.training_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist

# do the plot:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

ax1.scatter(training_samples_log10_P - np.amax(training_samples_log10_P), training_flow_log10_P - training_samples_log10_P, s=0.1, label='training')
ax1.scatter(validation_samples_log10_P - np.amax(validation_samples_log10_P), validation_flow_log10_P - validation_samples_log10_P, s=0.5, label='validation')
ax1.legend()
ax1.axhline(0, color='k', linestyle='--')
ax1.set_xlabel('$\log_{10}(P_{\\rm true}/P_{\\rm max})$')
ax1.set_ylabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax1.set_ylim(-1.0, 1.0)

ax2.hist(training_flow_log10_P - training_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='training')
ax2.hist(validation_flow_log10_P - validation_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='validation')
ax2.legend()
ax2.axvline(0, color='k', linestyle='--')
ax2.set_xlabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax2.set_xlim([-1.0, 1.0])

plt.tight_layout()
plt.show()
[image: scatter and histograms of flow minus true log-probabilities for training and validation samples]

As we can see this achieves performance very close to that of averaged flows. Combining the two strategies gives the best performance (but is slower to train).

Advanced Topic: Spline Flows¶

When more flexibility in the normalizing flow model is needed, we provide an implementation of neural spline flows, as discussed in Durkan et al. (2019), arXiv:1906.04032.
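The building block of these flows is the monotone rational-quadratic spline transform of Durkan et al. (2019). A minimal scalar numpy sketch of one such spline, with hand-picked illustrative knot values (in the actual flow the bin widths, heights and knot slopes are predicted by a masked autoregressive network):

```python
import numpy as np

def rq_spline(x, x_knots, y_knots, slopes):
    """Monotone rational-quadratic spline (Durkan et al. 2019),
    evaluated at a scalar x in [x_knots[0], x_knots[-1]].
    x_knots, y_knots: increasing knot positions (length K+1);
    slopes: positive derivatives at the knots (length K+1)."""
    # find the bin containing x:
    k = np.clip(np.searchsorted(x_knots, x) - 1, 0, len(x_knots) - 2)
    w = x_knots[k + 1] - x_knots[k]   # bin width
    h = y_knots[k + 1] - y_knots[k]   # bin height
    s = h / w                         # average slope in the bin
    xi = (x - x_knots[k]) / w         # relative position in the bin
    # rational-quadratic interpolant, monotone for positive slopes:
    num = h * (s * xi**2 + slopes[k] * xi * (1.0 - xi))
    den = s + (slopes[k + 1] + slopes[k] - 2.0 * s) * xi * (1.0 - xi)
    return y_knots[k] + num / den

# a 3-bin spline on [-1, 1]; it maps endpoints to endpoints and is
# strictly increasing, hence invertible:
xk = np.array([-1.0, -0.2, 0.5, 1.0])
yk = np.array([-1.0, 0.0, 0.3, 1.0])
d = np.array([1.0, 0.5, 2.0, 1.0])
grid = np.linspace(-1.0, 1.0, 201)
vals = np.array([rq_spline(t, xk, yk, d) for t in grid])
```

Because the map is monotone with positive knot slopes, it is invertible with a tractable Jacobian, which is what makes it usable as a flow bijector while being far more flexible than an affine transform.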

In [31]:
kwargs = {
          # flow settings:
          'pop_size': 1,
          'num_flows': 1,
          'epochs': 400,
          'transformation_type': 'spline',
          'autoregressive_type': 'masked',
          # feedback flags:
          'feedback': 1,
          'verbose': -1,
          'plot_every': 1000,
        }

spline_flow = sp.flow_from_chain(chain,  # parameter difference chain
                                 **kwargs)
* Initializing samples
    - time taken: 0.0010 seconds
* Initializing fixed bijector
    - time taken: 0.1885 seconds
* Initializing trainable bijector
WARNING: range_max should be larger than the maximum range of the data and is being adjusted.
    range_max: 5.0
    max range: 10.56929874420166
    new range_max: 11.569299
    - time taken: 1.3444 seconds
* Initializing training dataset
    - time taken: 1.1707 seconds
* Initializing transformed distribution
    - time taken: 0.0083 seconds
* Initializing loss function
    - time taken: 0.0000 seconds
* Initializing training model
    - Compiling model
    - time taken: 0.0283 seconds
    - time taken: 14.7903 seconds
* Training
    - Compiling model
    - time taken: 0.0270 seconds
In [32]:
# we can plot training summaries to make sure training went smoothly:
spline_flow.training_plot()
[image: training summary plots]
In [33]:
# we can triangle plot the flow to see how well it has learned the target distribution:
g = plots.get_subplot_plotter()
g.triangle_plot([chain, 
                 spline_flow.MCSamples(20000)  # this flow method returns a MCSamples object
                 ], 
                params=spline_flow.param_names,
                filled=True)
[image: triangle plot comparing the chain with spline flow samples]
In [34]:
samples = spline_flow.MCSamples(20000)
logP = spline_flow.log_probability(spline_flow.cast(samples.samples)).numpy()
samples.addDerived(logP, name='logP', label='\\log P')
samples.updateBaseStatistics();

# now let's plot everything:
g = plots.get_subplot_plotter()
g.triangle_plot([samples, chain], 
                plot_3d_with_param='logP',
                filled=False)
[image: triangle plot of flow samples colored by logP]
In [35]:
validation_flow_log10_P = spline_flow.log_probability(spline_flow.cast(chain.samples[spline_flow.test_idx, :])).numpy()/np.log(10.)
validation_samples_log10_P = -chain.loglikes[spline_flow.test_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist
training_flow_log10_P = spline_flow.log_probability(spline_flow.cast(chain.samples[spline_flow.training_idx, :])).numpy()/np.log(10.)
training_samples_log10_P = -chain.loglikes[spline_flow.training_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist

ev, eer = spline_flow.evidence()
print(f'log(Z) = {ev} +- {eer}')

# do the plot:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

ax1.scatter(training_samples_log10_P - np.amax(training_samples_log10_P), training_flow_log10_P - training_samples_log10_P, s=0.1, label='training')
ax1.scatter(validation_samples_log10_P - np.amax(validation_samples_log10_P), validation_flow_log10_P - validation_samples_log10_P, s=0.5, label='validation')
ax1.legend()
ax1.axhline(0, color='k', linestyle='--')
ax1.set_xlabel('$\log_{10}(P_{\\rm true}/P_{\\rm max})$')
ax1.set_ylabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax1.set_ylim(-1.0, 1.0)

ax2.hist(training_flow_log10_P - training_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='training')
ax2.hist(validation_flow_log10_P - validation_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='validation')
ax2.legend()
ax2.axvline(0, color='k', linestyle='--')
ax2.set_xlabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax2.set_xlim([-1.0, 1.0])

plt.tight_layout()
plt.show()
log(Z) = 0.07555282115936279 +- 0.7554653286933899
[image: scatter and histograms of flow minus true log-probabilities for training and validation samples]

We can check what happens across the bijector layers:

In [36]:
from tensiometer.synthetic_probability import flow_utilities as flow_utils

training_samples_spaces, validation_samples_spaces = \
    flow_utils.get_samples_bijectors(spline_flow, 
                                     feedback=True)
    
for i, _s in enumerate(training_samples_spaces):
    print('*  ', _s.name_tag)
    g = plots.get_subplot_plotter()
    g.triangle_plot([
                    training_samples_spaces[i],
                    validation_samples_spaces[i]], 
                    filled=True,
                    )
    plt.show()
0 - bijector name:  permute
1 - bijector name:  spline_flow
2 - bijector name:  permute
3 - bijector name:  spline_flow
4 - bijector name:  permute
5 - bijector name:  spline_flow
6 - bijector name:  permute
7 - bijector name:  spline_flow
8 - bijector name:  permute
9 - bijector name:  spline_flow
10 - bijector name:  permute
11 - bijector name:  spline_flow
12 - bijector name:  permute
13 - bijector name:  spline_flow
14 - bijector name:  permute
15 - bijector name:  spline_flow
*   original_space
*   training_space
*   0_after_permute
*   1_after_spline_flow
*   2_after_permute
*   3_after_spline_flow
*   4_after_permute
*   5_after_spline_flow
*   6_after_permute
*   7_after_spline_flow
*   8_after_permute
*   9_after_spline_flow
*   10_after_permute
*   11_after_spline_flow
*   12_after_permute
*   13_after_spline_flow
*   14_after_permute
*   15_after_spline_flow
[images: triangle plots of the training and validation samples in each of the spaces above]